package org.test4j.integration.testng;

import org.test4j.Context;
import org.test4j.integration.ListenerFactory;
import org.test4j.mock.startup.JavaAgentHits;
import org.testng.*;
import org.testng.annotations.Test;

import java.lang.reflect.Method;

import static org.test4j.mock.faking.util.StackTrace.filterStackTrace;

/**
 * TestNG listener that forwards TestNG lifecycle callbacks to test4j's {@link ListenerFactory}.
 *
 * <p>As an {@link IInvokedMethodListener} it is notified around every configuration and test method
 * invocation; as an {@link IExecutionListener} it is notified when the whole TestNG execution finishes.
 */
public final class TestNGListener implements IInvokedMethodListener, IExecutionListener {
    static {
        // Runs once, when this listener class is initialized.
        // NOTE(review): presumably reports java-agent status; see JavaAgentHits to confirm.
        JavaAgentHits.message();
    }

    /**
     * Called by TestNG before every configuration or test method invocation.
     *
     * <p>Configuration methods are mapped onto {@link ListenerFactory} hooks by
     * {@code beforeConfigurationMethod}. Test methods trigger {@link ListenerFactory#beforeExecute},
     * unless the result does not carry an instance of the test class (see {@code isTestNgResult}).
     *
     * @param invokedMethod the method TestNG is about to invoke
     * @param testResult    the result object TestNG associates with this invocation
     */
    @Override
    public void beforeInvocation(IInvokedMethod invokedMethod, ITestResult testResult) {
        Context.clearNoMockingZone();
        ITestNGMethod testNGMethod = testResult.getMethod();
        if (!invokedMethod.isTestMethod()) {
            this.beforeConfigurationMethod(testNGMethod, testResult);
        } else if (isTestNgResult(testResult)) {
            Method testMethod = testNGMethod.getConstructorOrMethod().getMethod();
            ListenerFactory.beforeExecute(testNGMethod.getInstance(), testMethod);
        }
    }

    /**
     * Called by TestNG after every configuration or test method invocation.
     *
     * <p>For test methods, {@link ListenerFactory#afterExecute} is notified with whatever the test
     * threw (possibly {@code null}). If that hook itself throws, the TestNG result is adjusted by
     * {@code setResult}. For configuration methods, only the stack trace of a failing
     * {@code @AfterMethod} is filtered.
     *
     * @param invokedMethod the method TestNG has just invoked
     * @param testResult    the result object TestNG associates with this invocation
     */
    @Override
    public void afterInvocation(IInvokedMethod invokedMethod, ITestResult testResult) {
        if (!invokedMethod.isTestMethod()) {
            filterTestResultThrowable(testResult);
            return;
        }
        // beforeInvocation skips beforeExecute for such results, so skip the matching afterExecute too;
        // otherwise afterExecute would receive a non-test instance without a preceding beforeExecute.
        if (!isTestNgResult(testResult)) {
            return;
        }
        ITestNGMethod testNGMethod = invokedMethod.getTestMethod();
        try {
            Throwable thrownByTest = testResult.getThrowable();
            Method testMethod = testNGMethod.getConstructorOrMethod().getMethod();
            ListenerFactory.afterExecute(testNGMethod.getInstance(), testMethod, thrownByTest);
        } catch (Throwable t) {
            setResult(invokedMethod, testResult, t);
        }
    }

    /**
     * Called once by TestNG when the whole execution has finished; notifies
     * {@link ListenerFactory#afterAll()}.
     */
    @Override
    public void onExecutionFinish() {
        ListenerFactory.afterAll();
    }

    /**
     * Maps a configuration method that is about to be invoked onto the matching {@link ListenerFactory} hook.
     *
     * @param method     the configuration method about to run
     * @param testResult the TestNG result for this invocation (used to obtain the real test class)
     */
    private void beforeConfigurationMethod(ITestNGMethod method, ITestResult testResult) {
        if (method.isBeforeClassConfiguration()) {
            Class<?> testClass = testResult.getTestClass().getRealClass();
            ListenerFactory.beforeAll(testClass);
        } else if (method.isBeforeMethodConfiguration()) {
            ListenerFactory.beforeMethod(method.getInstance());
        } else if (method.isAfterClassConfiguration()) {
            // NOTE(review): @AfterClass triggers afterMethod(), not afterAll(); presumably this cleans up
            // per-test state left by the last test of the class. Confirm intent.
            ListenerFactory.afterMethod();
        } else if (!method.isAfterMethodConfiguration()) {
            // Every other configuration kind triggers afterAll(); @AfterMethod deliberately triggers nothing.
            ListenerFactory.afterAll();
        }
    }

    /**
     * Filters the stack trace of the throwable recorded for a failing {@code @AfterMethod}
     * configuration method. Results of other configuration kinds, and results without a
     * throwable, are left untouched.
     *
     * @param testResult the TestNG result of the configuration method that just ran
     */
    private static void filterTestResultThrowable(ITestResult testResult) {
        ITestNGMethod method = testResult.getMethod();
        if (method.isAfterMethodConfiguration()) {
            Throwable thrownAfterTest = testResult.getThrowable();
            // A successful configuration method has no throwable; there is nothing to filter then.
            if (thrownAfterTest != null) {
                filterStackTrace(thrownAfterTest);
            }
        }
    }

    /**
     * Checks whether {@code thrown} is one of the {@code expectedExceptions} declared by the test's
     * {@link Test} annotation, either on the method or, failing that, on its declaring class.
     *
     * @param invokedMethod the invoked test method
     * @param thrown        the throwable to check (must not be {@code null})
     * @return {@code true} if an expected exception type is assignable from {@code thrown}'s class;
     *         {@code false} otherwise, including when no {@link Test} annotation is found
     */
    private static boolean isExpectedException(IInvokedMethod invokedMethod, Throwable thrown) {
        Method testMethod = invokedMethod.getTestMethod().getConstructorOrMethod().getMethod();
        Test testAnnotation = testMethod.getAnnotation(Test.class);
        if (testAnnotation == null) {
            // @Test may be declared only at class level, in which case the method itself is unannotated.
            testAnnotation = testMethod.getDeclaringClass().getAnnotation(Test.class);
        }
        if (testAnnotation == null) {
            return false;
        }
        Class<? extends Throwable> thrownType = thrown.getClass();
        for (Class<?> expectedException : testAnnotation.expectedExceptions()) {
            if (expectedException.isAssignableFrom(thrownType)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Adjusts the TestNG result after {@link ListenerFactory#afterExecute} threw {@code throwable}.
     *
     * @param invokedMethod the invoked test method
     * @param testResult    the TestNG result to adjust
     * @param throwable     the throwable raised by the after-execute hook (never {@code null})
     */
    private static void setResult(IInvokedMethod invokedMethod, ITestResult testResult, Throwable throwable) {
        filterStackTrace(throwable);
        Throwable thrownByTest = testResult.getThrowable();
        if (thrownByTest != null) {
            filterStackTrace(thrownByTest);
        }
        if (thrownByTest instanceof TestException && isExpectedException(invokedMethod, throwable)) {
            // NOTE(review): TestException is presumably TestNG's "expected exception was not thrown" failure.
            // The hook threw one of the expected types, so the test is reported as passed.
            testResult.setThrowable(null);
            testResult.setStatus(ITestResult.SUCCESS);
        } else if (thrownByTest == null || (testResult.isSuccess() && throwable != thrownByTest)) {
            // Report the hook's failure when the test threw nothing, or when the test passed while
            // throwing something other than the hook's throwable. An existing failure is left as-is.
            testResult.setThrowable(throwable);
            testResult.setStatus(ITestResult.FAILURE);
        }
    }

    /**
     * Checks whether TestNG's result carries an instance of the test class being run.
     *
     * <p>Returns {@code false} when TestNG is running a JUnit test. In that case the result
     * erroneously returns an {@code org.junit.runner.Description} object instead of the test instance.
     *
     * @param testResult the result TestNG passed to this listener
     * @return {@code true} if the result's instance is non-null and its class is exactly the test's
     *         real class; {@code false} otherwise
     */
    private boolean isTestNgResult(ITestResult testResult) {
        Object testInstance = testResult.getInstance();
        Class<?> testClass = testResult.getTestClass().getRealClass();
        return testInstance != null && testInstance.getClass() == testClass;
    }
}